
Dmoe integration #1210

Open · wants to merge 14 commits into base: main
Conversation

DayOfThePenguin (Contributor):

Supersedes #1197

This PR adds dropless MoE support using the Grouped GEMM implementation in megablocks.
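For context, a grouped GEMM multiplies each expert's batch of routed tokens by that expert's weight matrix in one fused kernel call. Below is a minimal, purely illustrative PyTorch sketch of the equivalent unfused computation; the names are invented for the example and are not taken from this PR or the megablocks API.

```python
import torch

def grouped_gemm_reference(x, expert_weights, tokens_per_expert):
    """Unfused reference for what a grouped GEMM computes.

    x:                 [num_tokens, hidden], tokens already sorted by expert
    expert_weights:    [num_experts, hidden, ffn_dim], one matrix per expert
    tokens_per_expert: [num_experts], rows of x assigned to each expert
    """
    outputs, start = [], 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        # Each expert multiplies only its own contiguous slice of tokens;
        # the fused grouped GEMM kernel performs all of these in one call.
        outputs.append(x[start:start + n] @ expert_weights[e])
        start += n
    return torch.cat(outputs, dim=0)
```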

Features

Unlike the legacy DeepSpeed MoE implementation, which uses the data parallel groups for expert parallelism, this implementation uses the model parallel group to parallelize the experts (a rough sketch follows the list below). This avoids the following problems:

  • Distributing the experts across data parallel groups incurs inter-node communication just to do a forward pass through a single layer.
  • MoE + pipeline parallelism is very hard to reason about when MoE weights are distributed across data parallel groups, and DeepSpeed doesn't natively support it.
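A rough sketch of the partitioning idea, assuming the experts divide evenly across the model-parallel group; the process-group handle and function name here are illustrative, not the PR's code.

```python
import torch.distributed as dist

def local_expert_slice(num_experts: int, mp_group=None):
    """Return the [start, end) range of expert indices owned by this rank.

    Each model-parallel rank holds a contiguous block of experts, so a forward
    pass through an MoE layer only communicates within the model-parallel
    group instead of across data-parallel (often inter-node) groups.
    """
    world_size = dist.get_world_size(group=mp_group)
    rank = dist.get_rank(group=mp_group)
    per_rank = num_experts // world_size
    return rank * per_rank, (rank + 1) * per_rank
```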

Clarified arguments by separating MoE args into their own class.

Sinkhorn routing is used by default during training and supports k >= 1; top-k routing is used for evaluation/inference.
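For reference, a minimal sketch of Sinkhorn-style routing: a few Sinkhorn-Knopp normalization passes over the exponentiated router logits approximately balance token-to-expert assignment during training, and top-k is then taken over the result. This is a generic formulation, not necessarily the PR's exact implementation.

```python
import torch

def sinkhorn(logits: torch.Tensor, n_iters: int = 4) -> torch.Tensor:
    """Approximately doubly-normalize router logits.

    logits: [num_tokens, num_experts] raw router scores.
    Returns a balanced assignment matrix; during training the expert choice is
    top-k over this matrix, while at evaluation/inference plain top-k over the
    raw logits is used.
    """
    cost = torch.exp(logits)
    d_tok = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d_exp = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    eps = 1e-8
    for _ in range(n_iters):
        # Alternately rescale rows (tokens) and columns (experts).
        d_tok = 1.0 / ((cost * d_exp.unsqueeze(0)).sum(dim=1) + eps)
        d_exp = 1.0 / ((cost * d_tok.unsqueeze(1)).sum(dim=0) + eps)
    return cost * d_tok.unsqueeze(1) * d_exp.unsqueeze(0)
```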

Testing

Tested pipeline-parallel sizes [3, 2, 1] and model-parallel sizes [1, 2, 4, 8] on Ampere GPUs.

Notes

Added megablocks and grouped_gemm to the dependencies. It might be desirable to pull some of the kernels in directly, as NVIDIA Megatron-Core does.

@Quentin-Anthony added the merge-queue label ("This PR is next on the queue to merge") on Dec 3, 2024
Quentin-Anthony and others added 2 commits December 3, 2024 15:53
- Removed mp assertion for moe
- Removed mlp_type checks in moe code
- Added Bf16 conversion to dmoe_gather
@@ -185,9 +180,102 @@ def _dmoe_gather(input_: torch.Tensor, tokens_per_expert: torch.Tensor):
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=gather_dim)

# Bf16 convert
@aurelion-source (Contributor), Dec 10, 2024:
Removing this results in fp32 output.

Reply (Contributor):
This was resolved in the latest commit.
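For illustration only (hypothetical names, not the PR's exact fix), the dtype handling under discussion amounts to casting the gathered result back to the input's dtype, so that a bf16 activation does not come out of the gather as fp32.

```python
import torch

def gather_and_cast(tensor_list, gather_dim: int, target_dtype: torch.dtype):
    # torch.cat already returns a contiguous tensor; cast back so a bf16
    # input does not end up as an fp32 output downstream.
    output = torch.cat(tensor_list, dim=gather_dim)
    return output if output.dtype == target_dtype else output.to(target_dtype)
```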

@aurelion-source (Contributor):

Profiles before and after the merge:
https://wandb.ai/shetano-personal/dmoe/reports/DMOE--VmlldzoxMDU0NDY0OQ

@@ -67,35 +58,26 @@

  # regularization
  "gradient_clipping": 1.0,
- "weight_decay": 0.1,
+ "weight_decay": 0.0,
Member:

there appear to be a lot of extraneous config changes. Any reason why?

@@ -1075,15 +1077,6 @@ def calculate_derived(self):
# if we set pipe_parallel_size to 0, GPT2ModelPipe.to_sequential() is called, and we run training with
# the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)
if self.moe_num_experts > 1:
Member:

Did we test these parallelism combinations?

Labels: merge-queue (This PR is next on the queue to merge)
3 participants